package main

import (
	"fmt"
	"os"
	"os/exec"
	"path"
	"text/template"

	"github.com/gomlx/gomlx/internal/backendparser"
	"github.com/gomlx/gomlx/internal/must"
	"github.com/gomlx/gomlx/pkg/support/sets"
	"k8s.io/klog/v2"
)

// binaryOpsFile is the name of the Go file produced by GenerateBinaryOps; it is also used
// as the name of binaryOpsTemplate.
const binaryOpsFile = "gen_binary_ops.go"

// IsBinaryOp reports whether method can be generated from binaryOpsTemplate.
//
// A qualifying method:
//   - takes exactly two parameters, both of type "Op",
//   - is not listed in BinaryOpsToExclude,
//   - returns exactly (Op, error).
func IsBinaryOp(method backendparser.Method) bool {
	params, outputs := method.Parameters, method.Outputs
	switch {
	case len(params) != 2:
		return false
	case params[0].Type != "Op", params[1].Type != "Op":
		return false
	case BinaryOpsToExclude.Has(method.Name):
		return false
	}
	return len(outputs) == 2 && outputs[0].Type == "Op" && outputs[1].Type == "error"
}

var (
	// BinaryOpAliases maps GoMLX backend binary op names to the names of their equivalent
	// functions in the github.com/gomlx/stablehlo package.
	// Ops not listed here are generated calling the stablehlo function with the same name.
	BinaryOpAliases = map[string]string{
		"Mul": "Multiply",
		"Div": "Divide",
		"Max": "Maximum",
		"Min": "Minimum",
		"Pow": "Power",
		"Rem": "Remainder",
		"Sub": "Subtract",

		// Bitwise and logical variants share the same StableHLO op.
		"BitwiseAnd": "And",
		"BitwiseOr":  "Or",
		"BitwiseXor": "Xor",
		"LogicalAnd": "And",
		"LogicalOr":  "Or",
		"LogicalXor": "Xor",
	}

	// BinaryOpsToExclude lists two-operand backend ops that IsBinaryOp rejects, so no code is
	// generated for them from binaryOpsTemplate.
	// NOTE(review): presumably these (Dot and the comparison ops) are implemented by hand
	// elsewhere because they don't map to a single same-named StableHLO call — confirm.
	BinaryOpsToExclude = sets.MakeWith("Dot",
		// Comparison ops.
		"Equal", "GreaterThan", "GreaterOrEqual", "LessThan", "LessOrEqual", "NotEqual",
		"EqualTotalOrder", "GreaterThanTotalOrder", "GreaterOrEqualTotalOrder", "LessThanTotalOrder", "LessOrEqualTotalOrder", "NotEqualTotalOrder",
	)

	// binaryOpsTemplate renders the contents of binaryOpsFile. It is executed with a slice of
	// values exposing .Method (the backend method being implemented) and .Alias (the stablehlo
	// function to call). Each generated Builder method broadcasts its operands with
	// broadcastForBinaryOps, calls stablehlo.<Alias> and wraps the result with newNode.
	binaryOpsTemplate = template.Must(template.New(binaryOpsFile).Parse(`
/***** File generated by ./internal/cmd/stablehlo_generator, based on the backends' ops. Don't edit it directly. *****/

package stablehlo

import (
	"github.com/gomlx/gomlx/backends"
	"github.com/gomlx/stablehlo"
)

{{- range .}}
{{- range .Method.Comments}}
{{.}}
{{- end}}
func (b *Builder) {{.Method.Name}}(lhs, rhs backends.Op) (backends.Op, error) {
	lhsNode, rhsNode, err := b.broadcastForBinaryOps(backends.OpType{{.Method.Name}}, lhs, rhs)
	if err != nil {
		return nil, err
	}
	value, err := stablehlo.{{.Alias}}(lhsNode.value, rhsNode.value)
	if err != nil {
		return nil, err
	}
	return b.newNode(value), nil
}
{{end}}
`))
)

// GenerateBinaryOps writes binaryOpsFile into the current working directory and then runs
// `go fmt` on it.
//
// The file contains one Builder method per entry of methods accepted by IsBinaryOp. Each
// method forwards to the stablehlo function named in BinaryOpAliases, or to the function
// with the same name as the method when no alias is registered.
//
// Any failure (getting the working directory, creating/writing/closing the file, executing
// the template, or running `go fmt`) panics via the must helpers.
func GenerateBinaryOps(methods []backendparser.Method) {
	// binaryOpData is the per-method value fed to binaryOpsTemplate.
	type binaryOpData struct {
		Method backendparser.Method // Backend method being implemented.
		Alias  string               // Name of the stablehlo function to call.
	}

	var data []binaryOpData
	for _, method := range methods {
		if !IsBinaryOp(method) {
			continue
		}
		alias := method.Name
		if a, ok := BinaryOpAliases[method.Name]; ok {
			alias = a
		}
		data = append(data, binaryOpData{Method: method, Alias: alias})
	}

	fileName := path.Join(must.M1(os.Getwd()), binaryOpsFile)
	f := must.M1(os.Create(fileName))
	must.M(binaryOpsTemplate.Execute(f, data))
	// Close the file before `go fmt` rewrites it in place: this releases the handle (some
	// platforms refuse to replace open files) and surfaces any deferred write error.
	must.M(f.Close())

	cmd := exec.Command("go", "fmt", fileName)
	klog.V(1).Infof("\t%s\n", cmd)
	cmd.Stderr = os.Stderr
	must.M(cmd.Run())
	fmt.Printf("✅ stablehlo_generator:  \tsuccessfully generated %s\n", fileName)
}
